# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

"""Discriminator architectures from the paper
"Efficient Geometry-aware 3D Generative Adversarial Networks"."""

import numpy as np
import torch
import dnnlib
from torch_utils import persistence
from torch_utils.ops import upfirdn2d
from models.networks_stylegan2 import DiscriminatorBlock, MappingNetwork, DiscriminatorEpilogue


@persistence.persistent_class
class SingleDiscriminator(torch.nn.Module):
    """Discriminator that scores only the ``'image'`` entry of its input dict.

    One `DiscriminatorBlock` is created per power-of-two resolution from
    `img_resolution` down to 8, followed by a `DiscriminatorEpilogue` at
    resolution 4. When `c_dim > 0`, a `MappingNetwork` embeds the conditioning
    label and the embedding is handed to the epilogue.
    """

    def __init__(self,
                 c_dim,                  # Conditioning label (C) dimensionality.
                 img_resolution,         # Input resolution.
                 img_channels,           # Number of input color channels.
                 architecture='resnet',  # Architecture: 'orig', 'skip', 'resnet'.
                 channel_base=32768,     # Overall multiplier for the number of channels.
                 channel_max=512,        # Maximum number of channels in any layer.
                 num_fp16_res=4,         # Use FP16 for the N highest resolutions.
                 conv_clamp=256,         # Clamp the output of convolution layers to +-X, None = disable clamping.
                 cmap_dim=None,          # Dimensionality of mapped conditioning label, None = default.
                 sr_upsample_factor=1,   # Ignored for SingleDiscriminator
                 block_kwargs={},        # Arguments for DiscriminatorBlock.
                 mapping_kwargs={},      # Arguments for MappingNetwork.
                 epilogue_kwargs={},     # Arguments for DiscriminatorEpilogue.
                 ):
        super().__init__()
        log2_res = int(np.log2(img_resolution))
        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = log2_res
        self.img_channels = img_channels
        # Block resolutions run from the input resolution down to 8 (exponent 3).
        self.block_resolutions = [2 ** exponent for exponent in range(log2_res, 2, -1)]

        # Channel width per resolution, including the epilogue's resolution 4.
        widths = {}
        for res in self.block_resolutions + [4]:
            widths[res] = min(channel_base // res, channel_max)

        # Blocks at or above this resolution run in FP16 (never below resolution 8).
        fp16_threshold = max(2 ** (log2_res + 1 - num_fp16_res), 8)

        # Without conditioning there is no label embedding at all.
        if c_dim == 0:
            cmap_dim = 0
        elif cmap_dim is None:
            cmap_dim = widths[4]

        shared_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
        layer_cursor = 0
        for res in self.block_resolutions:
            # The block at the input resolution has no incoming feature map.
            feat_in = widths[res] if res < img_resolution else 0
            block = DiscriminatorBlock(
                feat_in, widths[res], widths[res // 2],
                resolution=res,
                first_layer_idx=layer_cursor,
                use_fp16=(res >= fp16_threshold),
                **block_kwargs,
                **shared_kwargs
            )
            setattr(self, f'b{res}', block)
            layer_cursor += block.num_layers

        if c_dim > 0:
            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None,
                                          w_avg_beta=None, **mapping_kwargs)
        self.b4 = DiscriminatorEpilogue(widths[4], cmap_dim=cmap_dim, resolution=4,
                                        **epilogue_kwargs, **shared_kwargs)

    def forward(self, img, c, update_emas=False, **block_kwargs):
        """Run the discriminator on ``img['image']`` and return the epilogue output.

        `c` is only used when `c_dim > 0`; `update_emas` is accepted but unused.
        Extra keyword arguments are forwarded to every `DiscriminatorBlock`.
        """
        del update_emas  # unused
        image = img['image']

        features = None
        for res in self.block_resolutions:
            features, image = getattr(self, f'b{res}')(features, image, **block_kwargs)

        cmap = self.mapping(None, c) if self.c_dim > 0 else None
        return self.b4(features, image, cmap)

    def extra_repr(self):
        """Return the configuration summary shown in the module's repr."""
        return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'

# ----------------------------------------------------------------------------


def filtered_resizing(image_orig_tensor, size, f, filter_mode='antialiased'):
    """Resize an image batch to ``size`` x ``size`` with the chosen filtering.

    Args:
        image_orig_tensor: Image tensor to resize; it is passed to
            `torch.nn.functional.interpolate` with a 2-element target size.
        size: Target height and width (the output is square).
        f: Resampling filter handed to `upfirdn2d`; only used by the
            ``'classic'`` mode.
        filter_mode: One of
            ``'antialiased'`` - bilinear interpolation with antialiasing,
            ``'classic'`` - 2x filtered upsample, bilinear resize to
            ``2 * size + 2``, then 2x filtered downsample,
            ``'nearest'`` - nearest-neighbour interpolation,
            ``'none'`` - bilinear interpolation without antialiasing,
            or a float strictly between 0 and 1 giving the weight of the
            antialiased result in a blend with the non-antialiased one.

    Returns:
        The resized tensor.

    Raises:
        ValueError: If `filter_mode` is a float outside the open interval
            (0, 1) or is not one of the supported modes.
    """
    interpolate = torch.nn.functional.interpolate
    target = (size, size)

    if filter_mode == 'antialiased':
        return interpolate(image_orig_tensor, size=target, mode='bilinear', align_corners=False, antialias=True)

    if filter_mode == 'classic':
        resized = upfirdn2d.upsample2d(image_orig_tensor, f, up=2)
        resized = interpolate(resized, size=(size * 2 + 2, size * 2 + 2), mode='bilinear', align_corners=False)
        return upfirdn2d.downsample2d(resized, f, down=2, flip_filter=True, padding=-1)

    if filter_mode == 'nearest':
        return interpolate(image_orig_tensor, size=target, mode='nearest')

    if filter_mode == 'none':
        return interpolate(image_orig_tensor, size=target, mode='bilinear', align_corners=False)

    if isinstance(filter_mode, float):
        # Validate with an exception rather than `assert`, which is stripped under -O.
        if not 0 < filter_mode < 1:
            raise ValueError(f'Float filter_mode must lie strictly between 0 and 1, got {filter_mode!r}.')
        filtered = interpolate(image_orig_tensor, size=target, mode='bilinear', align_corners=False, antialias=True)
        aliased = interpolate(image_orig_tensor, size=target, mode='bilinear', align_corners=False, antialias=False)
        return (1 - filter_mode) * aliased + filter_mode * filtered

    # Previously an unknown mode fell through and failed with UnboundLocalError.
    raise ValueError(f'Unsupported filter_mode {filter_mode!r} encountered in filtered_resizing.')

# ----------------------------------------------------------------------------


@persistence.persistent_class
class DualDiscriminator(torch.nn.Module):
    """Discriminator fed with the image concatenated with an auxiliary image.

    The auxiliary input is selected by `aux_img_type`:
      * ``'mask'``: ``img['image_mask']`` rescaled by ``* 2 - 1`` (one extra channel),
      * ``'raw'``: ``img['image_raw']`` resized with nearest-neighbour
        interpolation to the width of ``img['image']`` (`img_channels` extra channels).

    The network is a stack of `DiscriminatorBlock`s from `img_resolution` down
    to 8 followed by an epilogue at resolution 4, which is built through
    `dnnlib.util.construct_class_by_name` from `epilogue_kwargs`.
    """

    def __init__(self,
                 c_dim,                  # Conditioning label (C) dimensionality.
                 img_resolution,         # Input resolution.
                 img_channels,           # Number of input color channels.
                 aux_img_type,           # Type of auxiliary input images: 'mask' or 'raw'.
                 architecture='resnet',  # Architecture: 'orig', 'skip', 'resnet'.
                 channel_base=32768,     # Overall multiplier for the number of channels.
                 channel_max=512,        # Maximum number of channels in any layer.
                 num_fp16_res=4,         # Use FP16 for the N highest resolutions.
                 conv_clamp=256,         # Clamp the output of convolution layers to +-X, None = disable clamping.
                 cmap_dim=None,          # Dimensionality of mapped conditioning label, None = default.
                 disc_c_noise=0,         # Corrupt camera parameters with X std dev of noise before disc. pose conditioning.
                 block_kwargs={},        # Arguments for DiscriminatorBlock.
                 mapping_kwargs={},      # Arguments for MappingNetwork.
                 epilogue_kwargs={},     # Arguments for the epilogue (passed to construct_class_by_name); not modified.
                 ):
        super().__init__()

        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.aux_img_type = aux_img_type
        # Total channel count after concatenating the auxiliary image.
        if aux_img_type == 'mask':
            self.img_channels = img_channels + 1
        elif aux_img_type == 'raw':
            self.img_channels = img_channels * 2
        else:
            raise NotImplementedError(f'Unsupported input images `{aux_img_type}` encountered in DualDiscriminator.')
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        # Blocks at or above this resolution use FP16 (never below resolution 8).
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if c_dim == 0:
            cmap_dim = 0

        common_kwargs = dict(img_channels=self.img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0
        for res in self.block_resolutions:
            # The block at the input resolution has no incoming feature map.
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
                                       first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers
        if c_dim > 0:
            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
        # Merge into a fresh dict so that neither the caller's `epilogue_kwargs`
        # nor the shared `{}` default is mutated.
        b4_kwargs = dict(epilogue_kwargs, in_channels=channels_dict[4], cmap_dim=cmap_dim, resolution=4, **common_kwargs)
        self.b4 = dnnlib.util.construct_class_by_name(**b4_kwargs)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter([1, 3, 3, 1]))
        self.disc_c_noise = disc_c_noise

    def forward(self, img, c, update_emas=False, **block_kwargs):
        """Score the image/auxiliary-image pair.

        Args:
            img: Mapping holding ``'image'`` plus ``'image_mask'`` or
                ``'image_raw'`` depending on `aux_img_type`.
            c: Conditioning labels; only read when `c_dim > 0`. The tensor is
                never modified in place.
            update_emas: Accepted but unused.
            **block_kwargs: Forwarded to every `DiscriminatorBlock`.

        Returns:
            The output of the epilogue.
        """
        if self.aux_img_type == 'mask':
            image_2nd = img['image_mask'] * 2 - 1  # normalize to [-1, 1]
        elif self.aux_img_type == 'raw':
            image_2nd = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter, filter_mode='nearest')
        else:
            # Only reachable if `aux_img_type` was changed after construction.
            raise NotImplementedError(f'Unsupported input images `{self.aux_img_type}` encountered in DualDiscriminator.')
        img = torch.cat([img['image'], image_2nd], 1)

        _ = update_emas  # unused
        x = None
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            x, img = block(x, img, **block_kwargs)

        cmap = None
        if self.c_dim > 0:
            if self.disc_c_noise > 0:
                # Out-of-place add: `c += ...` would write the noise into the
                # caller's tensor, which may be reused after this call.
                # NOTE(review): `c.std(0)` is taken over dim 0; with a single
                # row this presumably yields NaN - confirm batch size > 1.
                c = c + torch.randn_like(c) * c.std(0) * self.disc_c_noise
            cmap = self.mapping(None, c)
        x = self.b4(x, img, cmap)
        return x

    def extra_repr(self):
        """Return the configuration summary shown in the module's repr."""
        return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'

# ----------------------------------------------------------------------------


@persistence.persistent_class
class DummyDualDiscriminator(torch.nn.Module):
    """Dual-input discriminator whose raw-image branch is faded out over time.

    The input is ``img['image']`` concatenated with ``img['image_raw']``
    (resized to the same width and multiplied by `raw_fade`), so the network
    works on twice `img_channels` channels. `raw_fade` starts at 1 and is
    lowered by a fixed step on every `forward` call until it reaches 0; it is
    a plain Python attribute, not a registered buffer.
    """

    def __init__(self,
                 c_dim,                  # Conditioning label (C) dimensionality.
                 img_resolution,         # Input resolution.
                 img_channels,           # Number of input color channels.
                 architecture='resnet',  # Architecture: 'orig', 'skip', 'resnet'.
                 channel_base=32768,     # Overall multiplier for the number of channels.
                 channel_max=512,        # Maximum number of channels in any layer.
                 num_fp16_res=4,         # Use FP16 for the N highest resolutions.
                 conv_clamp=256,         # Clamp the output of convolution layers to +-X, None = disable clamping.
                 cmap_dim=None,          # Dimensionality of mapped conditioning label, None = default.
                 block_kwargs={},        # Arguments for DiscriminatorBlock.
                 mapping_kwargs={},      # Arguments for MappingNetwork.
                 epilogue_kwargs={},     # Arguments for DiscriminatorEpilogue.
                 ):
        super().__init__()
        # Image and raw image are stacked along the channel axis.
        stacked_channels = img_channels * 2
        log2_res = int(np.log2(img_resolution))

        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = log2_res
        self.img_channels = stacked_channels
        self.block_resolutions = [2 ** exponent for exponent in range(log2_res, 2, -1)]

        # Channel width per resolution, including the epilogue's resolution 4.
        widths = {}
        for res in self.block_resolutions + [4]:
            widths[res] = min(channel_base // res, channel_max)

        # Blocks at or above this resolution run in FP16 (never below resolution 8).
        fp16_threshold = max(2 ** (log2_res + 1 - num_fp16_res), 8)

        # Without conditioning there is no label embedding at all.
        if c_dim == 0:
            cmap_dim = 0
        elif cmap_dim is None:
            cmap_dim = widths[4]

        shared_kwargs = dict(img_channels=stacked_channels, architecture=architecture, conv_clamp=conv_clamp)
        layer_cursor = 0
        for res in self.block_resolutions:
            # The block at the input resolution has no incoming feature map.
            feat_in = widths[res] if res < img_resolution else 0
            block = DiscriminatorBlock(
                feat_in, widths[res], widths[res // 2],
                resolution=res,
                first_layer_idx=layer_cursor,
                use_fp16=(res >= fp16_threshold),
                **block_kwargs,
                **shared_kwargs
            )
            setattr(self, f'b{res}', block)
            layer_cursor += block.num_layers

        if c_dim > 0:
            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None,
                                          w_avg_beta=None, **mapping_kwargs)
        self.b4 = DiscriminatorEpilogue(widths[4], cmap_dim=cmap_dim, resolution=4,
                                        **epilogue_kwargs, **shared_kwargs)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter([1, 3, 3, 1]))

        self.raw_fade = 1

    def forward(self, img, c, update_emas=False, **block_kwargs):
        """Score ``img['image']`` stacked with the faded, resized ``img['image_raw']``.

        Every call first lowers `raw_fade` by one step (clamped at 0). `c` is
        only used when `c_dim > 0`; `update_emas` is accepted but unused.
        Extra keyword arguments are forwarded to every `DiscriminatorBlock`.
        """
        del update_emas  # unused

        fade_step = 1 / (500000 / 32)
        self.raw_fade = max(0, self.raw_fade - fade_step)

        raw_input = img['image_raw']
        full_image = img['image']
        resized_raw = filtered_resizing(raw_input, size=full_image.shape[-1], f=self.resample_filter)
        stacked = torch.cat([full_image, resized_raw * self.raw_fade], 1)

        features = None
        for res in self.block_resolutions:
            features, stacked = getattr(self, f'b{res}')(features, stacked, **block_kwargs)

        cmap = self.mapping(None, c) if self.c_dim > 0 else None
        return self.b4(features, stacked, cmap)

    def extra_repr(self):
        """Return the configuration summary shown in the module's repr."""
        return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'

# ----------------------------------------------------------------------------


@persistence.persistent_class
class TriDiscriminator(torch.nn.Module):
    """Discriminator fed with image, mask and raw image stacked along channels.

    The input is the concatenation of ``img['image']``, ``img['image_mask']``
    rescaled by ``* 2 - 1``, and ``img['image_raw']`` resized with
    nearest-neighbour interpolation to the width of ``img['image']``, so the
    network works on ``2 * img_channels + 1`` channels.

    The network is a stack of `DiscriminatorBlock`s from `img_resolution` down
    to 8 followed by an epilogue at resolution 4, which is built through
    `dnnlib.util.construct_class_by_name` from `epilogue_kwargs`.
    """

    def __init__(self,
                 c_dim,                  # Conditioning label (C) dimensionality.
                 img_resolution,         # Input resolution.
                 img_channels,           # Number of input color channels.
                 architecture='resnet',  # Architecture: 'orig', 'skip', 'resnet'.
                 channel_base=32768,     # Overall multiplier for the number of channels.
                 channel_max=512,        # Maximum number of channels in any layer.
                 num_fp16_res=4,         # Use FP16 for the N highest resolutions.
                 conv_clamp=256,         # Clamp the output of convolution layers to +-X, None = disable clamping.
                 cmap_dim=None,          # Dimensionality of mapped conditioning label, None = default.
                 disc_c_noise=0,         # Corrupt camera parameters with X std dev of noise before disc. pose conditioning.
                 block_kwargs={},        # Arguments for DiscriminatorBlock.
                 mapping_kwargs={},      # Arguments for MappingNetwork.
                 epilogue_kwargs={},     # Arguments for the epilogue (passed to construct_class_by_name); not modified.
                 ):
        super().__init__()
        # Image + one mask channel + raw image.
        img_channels = 2 * img_channels + 1

        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        # Blocks at or above this resolution use FP16 (never below resolution 8).
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if c_dim == 0:
            cmap_dim = 0

        common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0
        for res in self.block_resolutions:
            # The block at the input resolution has no incoming feature map.
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
                                       first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers
        if c_dim > 0:
            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
        # Merge into a fresh dict so that neither the caller's `epilogue_kwargs`
        # nor the shared `{}` default is mutated.
        b4_kwargs = dict(epilogue_kwargs, in_channels=channels_dict[4], cmap_dim=cmap_dim, resolution=4, **common_kwargs)
        self.b4 = dnnlib.util.construct_class_by_name(**b4_kwargs)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter([1, 3, 3, 1]))
        self.disc_c_noise = disc_c_noise

    def forward(self, img, c, update_emas=False, **block_kwargs):
        """Score the stacked image / mask / raw-image input.

        Args:
            img: Mapping holding ``'image'``, ``'image_mask'`` and ``'image_raw'``.
            c: Conditioning labels; only read when `c_dim > 0`. The tensor is
                never modified in place.
            update_emas: Accepted but unused.
            **block_kwargs: Forwarded to every `DiscriminatorBlock`.

        Returns:
            The output of the epilogue.
        """
        image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter, filter_mode='nearest')
        image_mask = img['image_mask'] * 2 - 1  # normalize to [-1, 1]
        img = torch.cat([img['image'], image_mask, image_raw], 1)

        _ = update_emas  # unused
        x = None
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            x, img = block(x, img, **block_kwargs)

        cmap = None
        if self.c_dim > 0:
            if self.disc_c_noise > 0:
                # Out-of-place add: `c += ...` would write the noise into the
                # caller's tensor, which may be reused after this call.
                # NOTE(review): `c.std(0)` is taken over dim 0; with a single
                # row this presumably yields NaN - confirm batch size > 1.
                c = c + torch.randn_like(c) * c.std(0) * self.disc_c_noise
            cmap = self.mapping(None, c)
        x = self.b4(x, img, cmap)
        return x

    def extra_repr(self):
        """Return the configuration summary shown in the module's repr."""
        return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'

# ----------------------------------------------------------------------------


@persistence.persistent_class
class DiscriminatorSuperRes(torch.nn.Module):
    """Discriminator with a main trunk on the image and a second trunk on the raw image.

    Two stacks of `DiscriminatorBlock`s are built:
      * ``b{res}`` for resolutions from `img_resolution` down to just above
        `out_resolution`, applied to ``img['image']``;
      * ``cb{res}`` for resolutions from `raw_resolution` down to just above
        `out_resolution`, applied to ``img['image_raw']``.

    The output of the second stack is passed to the epilogue (`bout`) in the
    `cmap` position. The epilogue class is built through
    `dnnlib.util.construct_class_by_name` from `epilogue_kwargs`.

    Channel widths for the ``cb*`` blocks are looked up in the table built for
    `block_resolutions` plus `out_resolution`, so every entry of
    `c_resolutions` must also appear there.
    """

    def __init__(self,
                 raw_resolution,         # Raw image resolution.
                 img_resolution,         # Input resolution.
                 img_channels,           # Number of input color channels.
                 out_resolution=None,    # Output resolution, None = 4.
                 architecture='resnet',  # Architecture: 'orig', 'skip', 'resnet'.
                 channel_base=32768,     # Overall multiplier for the number of channels.
                 channel_max=512,        # Maximum number of channels in any layer.
                 num_fp16_res=4,         # Use FP16 for the N highest resolutions.
                 conv_clamp=256,         # Clamp the output of convolution layers to +-X, None = disable clamping.
                 cmap_dim=None,          # Dimensionality of mapped conditioning label, None = default.
                 disc_c_noise=0,         # Stored on the module; not read by forward().
                 block_kwargs={},        # Arguments for DiscriminatorBlock.
                 mapping_kwargs={},      # Ignored.
                 epilogue_kwargs={},     # Arguments for the epilogue (passed to construct_class_by_name); not modified.
                 ):
        super().__init__()

        self.raw_resolution = raw_resolution
        self.img_resolution = img_resolution
        self.out_resolution = out_resolution if out_resolution is not None else 4
        self.raw_resolution_log2 = int(np.log2(self.raw_resolution))
        self.img_resolution_log2 = int(np.log2(self.img_resolution))
        self.out_resolution_log2 = int(np.log2(self.out_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, self.out_resolution_log2, -1)]
        self.c_resolutions = [2 ** i for i in range(self.raw_resolution_log2, self.out_resolution_log2, -1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [self.out_resolution]}
        # Blocks at or above this resolution use FP16 (never below resolution 8).
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        if cmap_dim is None:
            cmap_dim = channels_dict[self.out_resolution]

        common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0
        # Main trunk, applied to the full-resolution image.
        for res in self.block_resolutions:
            # The block at the input resolution has no incoming feature map.
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
                                       first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers
        # Second trunk, applied to the raw image; layer indices continue from the main trunk.
        for res in self.c_resolutions:
            in_channels = channels_dict[res] if res < raw_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
                                       first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
            setattr(self, f'cb{res}', block)
            cur_layer_idx += block.num_layers
        # Merge into a fresh dict so that neither the caller's `epilogue_kwargs`
        # nor the shared `{}` default is mutated.
        bout_kwargs = dict(epilogue_kwargs, in_channels=channels_dict[self.out_resolution], cmap_dim=cmap_dim,
                           resolution=self.out_resolution, **common_kwargs)
        self.bout = dnnlib.util.construct_class_by_name(**bout_kwargs)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter([1, 3, 3, 1]))
        self.disc_c_noise = disc_c_noise

    def forward(self, img, c, update_emas=False, **block_kwargs):
        """Score the image using features of the raw image as conditioning.

        Args:
            img: Mapping holding ``'image'`` and ``'image_raw'``.
            c: Accepted but not read by this method.
            update_emas: Accepted but unused.
            **block_kwargs: Forwarded to every `DiscriminatorBlock`.

        Returns:
            The output of the epilogue `bout`.
        """
        raw = img['image_raw']
        img = img['image']

        _ = update_emas  # unused
        x = None
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            x, img = block(x, img, **block_kwargs)
        cmap = None
        for res in self.c_resolutions:
            block = getattr(self, f'cb{res}')
            # Every `cb*` block receives the original `raw` tensor as its image
            # input. `img` is rebound here, so the epilogue below gets the image
            # returned by the last `cb*` block rather than the main trunk's.
            # NOTE(review): confirm that overwriting `img` is intended.
            cmap, img = block(cmap, raw, **block_kwargs)

        x = self.bout(x, img, cmap)
        return x

    def extra_repr(self):
        """Return the configuration summary shown in the module's repr."""
        return f'raw_resolution={self.raw_resolution:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'

# ----------------------------------------------------------------------------
